﻿// Copyright (c) .NET Foundation and Contributors. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Roslynator.CSharp.CSharpFactory;

namespace Roslynator.CSharp.Refactorings.ConvertReturnToIf;

/// <summary>
/// Base implementation of the "convert return statement to if" refactoring. It rewrites a statement
/// that carries a boolean expression (e.g. <c>return x;</c>) into an <c>if</c>/<c>else</c> whose
/// branches carry boolean literals (e.g. <c>if (x) { return true; } else { return false; }</c>).
/// </summary>
/// <typeparam name="TStatement">
/// The statement kind being converted; derived classes define how its expression is read and replaced.
/// </typeparam>
internal abstract class ConvertReturnToIfRefactoring<TStatement> where TStatement : StatementSyntax
{
    /// <summary>
    /// Returns the expression carried by <paramref name="statement"/>, or <c>null</c> when there is
    /// none (no refactoring is offered in that case).
    /// </summary>
    protected abstract ExpressionSyntax GetExpression(TStatement statement);

    /// <summary>
    /// Returns <paramref name="statement"/> with its expression replaced by <paramref name="expression"/>
    /// (as implemented by the derived class).
    /// </summary>
    protected abstract TStatement SetExpression(TStatement statement, ExpressionSyntax expression);

    /// <summary>
    /// Returns the title shown for the refactoring offered on <paramref name="statement"/>.
    /// </summary>
    protected abstract string GetTitle(TStatement statement);

    /// <summary>
    /// Registers the refactoring for <paramref name="statement"/> when its expression is not a
    /// <c>true</c>/<c>false</c> literal and the expression's converted type is <see cref="bool"/>.
    /// </summary>
    /// <param name="context">The refactoring context used to query semantics and register the action.</param>
    /// <param name="statement">The statement the refactoring is computed for.</param>
    public async Task ComputeRefactoringAsync(RefactoringContext context, TStatement statement)
    {
        ExpressionSyntax expression = GetExpression(statement);

        if (expression is null)
            return;

        // Converting `return true;` / `return false;` would only produce a redundant if/else.
        if (CSharpFacts.IsBooleanLiteralExpression(expression.Kind()))
            return;

        SemanticModel semanticModel = await context.GetSemanticModelAsync().ConfigureAwait(false);

        ITypeSymbol convertedType = semanticModel
            .GetTypeInfo(expression, context.CancellationToken)
            .ConvertedType;

        // The branches of the generated if/else carry bool literals, so only offer the refactoring
        // when the original expression is converted to bool.
        if (convertedType?.SpecialType != SpecialType.System_Boolean)
            return;

        context.ThrowIfCancellationRequested();

        context.RegisterRefactoring(
            GetTitle(statement),
            ct => RefactorAsync(context.Document, statement, expression, ct),
            RefactoringDescriptors.ConvertReturnStatementToIf);
    }

    /// <summary>
    /// Replaces <paramref name="statement"/> in <paramref name="document"/> with the equivalent
    /// <c>if</c>/<c>else</c> statement.
    /// </summary>
    private Task<Document> RefactorAsync(
        Document document,
        TStatement statement,
        ExpressionSyntax expression,
        CancellationToken cancellationToken = default)
    {
        // The new if statement takes over the original statement's leading and trailing trivia
        // (indentation, comments); the formatter annotation lays out the generated blocks.
        IfStatementSyntax ifStatement = CreateIfStatement(statement, expression)
            .WithTriviaFrom(statement)
            .WithFormatterAnnotation();

        cancellationToken.ThrowIfCancellationRequested();

        return document.ReplaceNodeAsync(statement, ifStatement, cancellationToken);
    }

    /// <summary>
    /// Builds the if statement for <paramref name="expression"/>. A two-operand <c>a || b</c> is split
    /// into <c>if (a) { ... true } else { ... b }</c>; anything else becomes
    /// <c>if (expression) { ... true } else { ... false }</c>.
    /// </summary>
    private IfStatementSyntax CreateIfStatement(TStatement statement, ExpressionSyntax expression)
    {
        if (expression is BinaryExpressionSyntax binaryExpression
            && binaryExpression.IsKind(SyntaxKind.LogicalOrExpression))
        {
            ExpressionSyntax left = binaryExpression.Left;
            ExpressionSyntax right = binaryExpression.Right;

            // `a || b || c` parses as `(a || b) || c`; only the simple two-operand form is split.
            if (left is not null
                && right is not null
                && !left.IsKind(SyntaxKind.LogicalOrExpression))
            {
                return CreateIfStatement(statement, left, TrueLiteralExpression(), right.WithoutTrivia());
            }
        }

        return CreateIfStatement(statement, expression, TrueLiteralExpression(), FalseLiteralExpression());
    }

    /// <summary>
    /// Builds <c>if (condition) { statement(left) } else { statement(right) }</c>, where each branch is a
    /// copy of <paramref name="statement"/> with its expression replaced.
    /// </summary>
    private IfStatementSyntax CreateIfStatement(TStatement statement, ExpressionSyntax condition, ExpressionSyntax left, ExpressionSyntax right)
    {
        // Strip both leading AND trailing trivia: the caller re-attaches the original trivia to the
        // whole if statement, so keeping trailing trivia here (e.g. an end-of-line comment) would
        // duplicate it into both branches.
        statement = statement.WithoutTrivia();

        return IfStatement(
            condition,
            Block(SetExpression(statement, left)),
            ElseClause(
                Block(SetExpression(statement, right))));
    }
}
